import inspect
import dash
import json
import jsondiff as jd
import dash_table as dt
import dash_html_components as html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from jsondiff import diff

from ..utils import force_list


def new(app, id_prefix, columns, **kwargs):
    """Build an editable DataTable with an add/confirm button and wire its callbacks.

    Args:
        app: Object whose ``app.callback`` decorator registers the callbacks
            (presumably the Dash app).
        id_prefix: Prefix for component ids; the table gets
            ``'<prefix>-table'`` and the button ``'<prefix>-action-button'``.
        columns: DataTable column definitions. The caller's sequence is not
            mutated; when ``inspect_link_getter`` is callable, a markdown
            ``'id_link'`` column is appended to a copy.
        **kwargs: Optional settings:
            - ``on_init``: zero-argument callable whose result replaces the
              rows when the callback fires with no triggering component.
            - ``inspect_link_getter``: ``row -> url`` used to fill ``id_link``.
            - ``data_source_store``: id of a component whose ``data`` property
              (a JSON string) replaces the rows whenever it changes.
            - ``data_source_reader``: callable applied to the decoded store data.
            - ``data``: initial rows (default ``[]``).
            - ``on_insert`` / ``on_update`` / ``on_delete``: forwarded to the
              edit handlers.
            - Any other key accepted by ``dt.DataTable`` is passed through.

    Returns:
        The ``dbc.Container`` layout holding the table and the control button.
    """
    # Read args
    table_id = f'{id_prefix}-table'
    control_id = f'{id_prefix}-action-button'
    on_init_fun = kwargs.get('on_init')
    inspect_link_getter = kwargs.get('inspect_link_getter')
    data_source_reader = kwargs.get('data_source_reader')
    data_source_store = kwargs.get('data_source_store')
    has_store = isinstance(data_source_store, str) and bool(data_source_store)

    # Forward only the kwargs DataTable understands. `id`, `columns` and `data`
    # are passed explicitly below, so they must be excluded here or the call
    # fails with "got multiple values for keyword argument".
    dt_signature_args = inspect.signature(dt.DataTable).parameters
    explicit_args = {'id', 'columns', 'data'}
    dt_kwargs = {
        key: val
        for key, val in kwargs.items()
        if key in dt_signature_args and key not in explicit_args
    }

    # Work on a copy so the caller's column list is never mutated (reusing the
    # same list for two tables would otherwise add `id_link` twice).
    columns = list(columns)
    if callable(inspect_link_getter):
        columns.append({
            'name': 'ID',
            'id': 'id_link',
            'presentation': 'markdown'
        })

    content = dbc.Container([
        dbc.Row([
            dbc.Col([
                dt.DataTable(
                    id=table_id,
                    columns=columns,
                    data=kwargs.get('data', []),
                    **dt_kwargs
                )
            ], md=12)
        ]),
        dbc.Row([
            dbc.Col([
                html.Button(
                    '新增',
                    id=control_id,
                    n_clicks=0,
                    style={
                        'width': '100%',
                        'margin-top': '15px'
                    }),
            ], md=3),
        ])
    ])

    callback_args = [
        Output(table_id, 'data'),
        Output(control_id, 'children'),
        Input(control_id, 'n_clicks'),
        Input(table_id, 'data'),
        Input(table_id, 'data_previous'),
    ]

    if has_store:
        callback_args.append(Input(data_source_store, 'data'))

    callback_args.extend([
        State(table_id, 'columns'),
        State(control_id, 'children')
    ])

    @app.callback(*callback_args)
    def dispatch(_, rows, prev_rows, *args):
        """Route a table/button/store event to the matching handler."""
        # The last two positional args are the States; when a store is
        # configured, its Input value sits immediately before them.
        cols, next_action = args[-2], args[-1]

        ctx = dash.callback_context
        prop_id = ctx.triggered[0]['prop_id'].split('.')[0]
        state = _infer_table_state(next_action)

        print(f'[Table] Event fired: by = {prop_id}. current state = {state}')

        # No triggering component (presumably the initial call): load rows.
        if not prop_id and callable(on_init_fun):
            rows = on_init_fun()

        # The event is caused by the data source update. Guarded by
        # `has_store` so an empty store id never matches the empty prop_id.
        if has_store and prop_id == data_source_store:
            rows = json.loads(args[-3])
            if callable(data_source_reader):
                rows = data_source_reader(rows)

        # The event is fired by the table control button
        if prop_id == control_id:
            next_action, rows = _handle_do_next_action(next_action, rows, cols, **kwargs)

        # The event is fired by the table itself (user edit)
        if prop_id == table_id:
            rows = _handle_edit(state, rows, prev_rows, **kwargs)

        # Populate instance links when the link column exists and a getter
        # is available to build them.
        has_link_col = any(col['id'] == 'id_link' for col in cols or ())
        if has_link_col and callable(inspect_link_getter):
            for row in rows or ():
                if row.get('id'):
                    link = inspect_link_getter(row)
                    row['id_link'] = f'[{row["id"]}]({link})'

        return rows, next_action

    @app.callback(
        Output(table_id, 'row_deletable'),
        Input(control_id, 'children'))
    def toggle_control(next_action):
        """Allow row deletion only while the table is idle (not adding a row)."""
        state = _infer_table_state(next_action)

        return state == 'idle'

    return content


def _handle_do_next_action(action, rows, cols, **kwargs):
    if action.lower() == '新增':
        rows.append({
            col['id']: ''
            for col in cols
        })
        return '确认', rows

    if action.lower() == '确认':
        next_action = '新增'

        """
        Some hacks here! When edition is confirmed, we want to trigger an instance creation
        event. But because the true `pre_row` (when there is one less  row) is long gone, the
        change detected by `_infer_table_change` will be `update`, if we just chunk in the current
        `prev_row`.

        So! let's trick the `_infer_table_change` by giving it rows with the latest row removed.
        """
        prev_rows = rows[:-1]
        _handle_edit('idle', rows, prev_rows, **kwargs)
        return next_action, rows


def _handle_edit(state, rows, prev_rows, **kwargs):
    """Classify the edit between two row snapshots and run the matching hook.

    Hooks are looked up in ``kwargs`` as ``on_delete``, ``on_insert`` and
    ``on_update``; a missing or non-callable hook is skipped. Every hook is
    called as ``hook(state, row, rows)``. The results of ``on_insert`` and
    ``on_update`` replace the corresponding entries of ``rows`` in place.

    Returns:
        The same ``rows`` list object, possibly with entries replaced.
    """
    change, changed_rows, changed_idx = _infer_table_change(rows, prev_rows)
    print(f'Table edit action is {change}.')

    if change == 'delete':
        hook = kwargs.get('on_delete')
        if callable(hook):
            for removed in changed_rows:
                hook(state, removed, rows)
    elif change == 'insert':
        hook = kwargs.get('on_insert')
        if callable(hook):
            for pos, added in zip(changed_idx, changed_rows):
                rows[pos] = hook(state, added, rows)
    elif change == 'update':
        hook = kwargs.get('on_update')
        if callable(hook):
            for pos in changed_idx:
                rows[pos] = hook(state, rows[pos], rows)

    return rows


def _infer_table_state(next_action):
    if next_action.lower() == '新增':
        return 'idle'

    if next_action.lower() == '确认':
        return 'editing'


def _infer_table_change(rows, prev_rows):
    """Classify the difference between two snapshots of the table rows.

    Uses ``jsondiff.diff(prev_rows, rows)`` and inspects its compact result.

    Returns:
        ``(action, affected_rows, affected_idx)`` where ``action`` is one of
        ``'none'``, ``'delete'``, ``'insert'`` or ``'update'``:

        - ``'delete'``: removed rows taken from ``prev_rows`` and their
          positions there.
        - ``'insert'``: added rows and their positions in ``rows``.
        - ``'update'``: the *pre-edit* values (from ``prev_rows``) of the
          changed rows and their positions.
        - ``'none'``: ``None`` for both other elements.

        Only one kind is reported. If a diff contains both deletions and
        insertions, deletions win.
    """
    # `prev_rows` may be None (it is fed from the table's `data_previous`);
    # treat that as an empty table so the edge-case branch below never
    # evaluates `len(None)`.
    if prev_rows is None:
        prev_rows = []

    summary = diff(prev_rows, rows)

    if summary is None:
        return 'none', None, None

    # Detect edgy cases where jsondiff returns a plain list instead of a
    # change dict. See https://github.com/xlwings/jsondiff/issues/10
    if isinstance(summary, list):
        if summary == []:
            # [1] Empty replacement: every previous row counts as deleted
            return 'delete', prev_rows, list(range(len(prev_rows)))
        # [2] Non-empty replacement (e.g. the first row was added):
        # every current row counts as inserted
        return 'insert', rows, list(range(len(rows)))

    # An empty change dict means the snapshots are identical.
    if not summary:
        return 'none', None, None

    # Detect deletions
    deleted_rows_idx = summary.get(jd.delete)
    if deleted_rows_idx:
        deleted_rows = [prev_rows[idx] for idx in deleted_rows_idx]
        return 'delete', deleted_rows, deleted_rows_idx

    # Detect additions; entries are (position, row) pairs
    added_rows_desp = summary.get(jd.insert)
    if added_rows_desp:
        added_rows = [entry[1] for entry in added_rows_desp]
        added_rows_idx = [entry[0] for entry in added_rows_desp]
        return 'insert', added_rows, added_rows_idx

    # Detect updates: the remaining integer keys are positions of modified
    # rows. Non-int keys (jsondiff marker symbols) are never list indices.
    updated_rows_idx = [key for key in summary if isinstance(key, int)]
    updated_rows = [prev_rows[idx] for idx in updated_rows_idx]
    return 'update', updated_rows, updated_rows_idx
